{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem: Implement Mixed Precision Training Using `torch.cuda.amp`\n",
    "\n",
    "### Problem Statement\n",
    "Mixed precision training uses both 16-bit and 32-bit floating-point types to accelerate training and reduce memory usage without significantly affecting model performance. Your task is to implement mixed precision training for a deep learning model using PyTorch's `torch.cuda.amp`.\n",
    "\n",
    "### Requirements\n",
    "\n",
    "1. **Enable Mixed Precision Training**:\n",
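    "   - Use the `torch.cuda.amp.autocast` context manager to run the forward pass in mixed precision.\n",
    "   - Use a gradient scaler to scale the loss before backpropagation, so that small gradients do not underflow in 16-bit precision and training stays numerically stable."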
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Implement mixed precision training in PyTorch using torch.cuda.amp\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a simple model\n",
    "class SimpleModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SimpleModel, self).__init__()\n",
    "        self.fc = nn.Linear(10, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc(x)\n",
    "\n",
    "# Generate synthetic data\n",
    "# Seed the global RNG first so the data, the initial weights and the shuffle order\n",
    "# are the same on every Restart & Run All\n",
    "torch.manual_seed(42)\n",
    "X = torch.randn(1000, 10)  # 1000 samples with 10 features each\n",
    "y = torch.randn(1000, 1)  # one regression target per sample\n",
    "dataset = TensorDataset(X, y)\n",
    "dataloader = DataLoader(dataset, batch_size=32, shuffle=True)\n",
    "\n",
    "# Initialize model, loss function, and optimizer\n",
    "model = SimpleModel().cuda()  # moves the parameters to the GPU; requires CUDA\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Enable mixed precision training\n",
    "# TODO: replace the `...` placeholder with the gradient scaler used by the training loop\n",
    "scaler = ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop\n",
    "epochs = 5\n",
    "for epoch in range(epochs):\n",
    "    for inputs, labels in dataloader:\n",
    "        # Move the batch to the GPU, where the model was placed\n",
    "        inputs, labels = inputs.cuda(), labels.cuda()\n",
    "\n",
    "        # Forward pass under autocast\n",
    "        with torch.cuda.amp.autocast():\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass with scaled gradients\n",
    "        optimizer.zero_grad()\n",
    "        # TODO: replace each `...` placeholder below with the argument that call needs\n",
    "        scaler.scale(...).backward()\n",
    "        scaler.step(...)\n",
    "        scaler.update()\n",
    "\n",
    "    # Reports the loss of the last batch of the epoch, not an epoch average\n",
    "    print(f\"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference on freshly sampled inputs\n",
    "X_test = torch.randn(5, 10).cuda()\n",
    "# Gradient tracking is disabled for prediction; the forward pass still runs under autocast\n",
    "with torch.no_grad():\n",
    "    with torch.cuda.amp.autocast():\n",
    "        predictions = model(X_test)\n",
    "        print(\"Predictions:\", predictions)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
